PEM survival model with random-walk baseline hazard

import random
import survivalstan
import numpy as np
import pandas as pd
from stancache import stancache
from matplotlib import pyplot as plt
INFO:stancache.seed:Setting seed to 1245502385
# Stan program for the PEM survival model with a random-walk baseline hazard;
# it is passed as `model_code` to fit_stan_survival_model further down.
model_code = survivalstan.models.pem_survival_model_randomwalk
/*  Variable naming:
 // dimensions
 N          = total number of observations (length of data)
 S          = number of sample ids
 T          = max timepoint (number of timepoint ids)
 M          = number of covariates

 // main data matrix (per observed timepoint*record)
 s          = sample id for each obs
 t          = timepoint id for each obs
 event      = integer indicating if there was an event at time t for sample s
 x          = matrix of real-valued covariates at time t for sample n [N, M]

 // timepoint-specific data (per timepoint, ordered by timepoint id)
 t_obs      = observed time since origin for each timepoint id (end of period)
 t_dur      = duration of each timepoint period (first diff of t_obs)
*/
// Jacqueline Buros Novik <>

// Input data: one row per observed (sample, timepoint) record, plus
// per-timepoint timing information.
data {
  // dimensions
  int<lower=1> N;                 // number of observations (rows)
  int<lower=1> S;                 // number of sample ids
  int<lower=1> T;                 // number of timepoint ids
  int<lower=0> M;                 // number of covariates

  // data matrix
  // s indexes the first dimension of n_trans[S, T], so it is bounded by S
  int<lower=1, upper=S> s[N];     // sample id
  int<lower=1, upper=T> t[N];     // timepoint id
  int<lower=0, upper=1> event[N]; // 1: event, 0:censor
  matrix[N, M] x;                 // explanatory vars

  // timepoint data (ordered by timepoint id)
  vector<lower=0>[T] t_obs;       // observed time since origin (end of period)
  vector<lower=0>[T] t_dur;       // duration of each period
}
// Precompute the per-timepoint log-duration offset and a lookup table from
// (sample, timepoint) to the row index n of the data matrix.
transformed data {
  vector[T] log_t_dur;  // log-duration for each timepoint
  int n_trans[S, T];    // row index n for each sample*timepoint

  // offset is the log of the period *duration* (t_dur), not of the
  // cumulative observed time (t_obs)
  log_t_dur = log(t_dur);

  // n_trans used to map each sample*timepoint to n (used in gen quantities)
  // map each patient/timepoint combination to n values
  for (n in 1:N) {
      n_trans[s[n], t[n]] = n;
  }

  // fill in missing values with n for max t for that patient
  // ie assume "last observed" state applies forward (may be problematic for TVC)
  // this allows us to predict failure times >= observed survival times
  for (samp in 1:S) {
      int last_value;
      last_value = 0;
      for (tp in 1:T) {
          // manual says ints are initialized to neg values
          // so <=0 is a shorthand for "unassigned"
          if (n_trans[samp, tp] <= 0 && last_value != 0) {
              n_trans[samp, tp] = last_value;
          } else {
              last_value = n_trans[samp, tp];
          }
      }
  }
}
// Sampled parameters: per-timepoint baseline, covariate effects, and the
// random-walk scale / overall intercept.
parameters {
  vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t
  vector[M] beta;             // beta for each covariate
  real<lower=0> baseline_sigma; // sd of the random-walk steps on log_baseline_raw
  real log_baseline_mu;         // intercept added to every log hazard
}
// Assemble the log hazard for every observation from the intercept, the
// duration-offset baseline of its timepoint, and its covariates.
transformed parameters {
  vector[N] log_hazard;
  vector[T] log_baseline;

  // add the log-duration offset to the raw per-timepoint baseline
  log_baseline = log_baseline_raw + log_t_dur;

  for (n in 1:N) {
    log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + x[n,]*beta;
  }
}
// Priors and likelihood: Poisson likelihood on the event indicator with a
// first-order random walk prior on the raw log baseline.
model {
  beta ~ cauchy(0, 2);
  event ~ poisson_log(log_hazard);
  log_baseline_mu ~ normal(0, 1);
  baseline_sigma ~ normal(0, 1);

  // random walk: first element anchored, each later element centered on
  // its predecessor with sd baseline_sigma
  log_baseline_raw[1] ~ normal(0, 1);
  for (i in 2:T) {
      log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);
  }
}
// Per-observation log-likelihood, the baseline on the natural scale, and
// posterior-predictive event simulation for every sample*timepoint.
generated quantities {
  real log_lik[N];
  vector[T] baseline;
  int y_hat_mat[S, T];     // ppcheck for each S*T combination
  real y_hat_time[S];      // predicted failure time for each sample
  int y_hat_event[S];      // predicted event (0:censor, 1:event)

  // compute raw baseline hazard, for summary/plotting
  baseline = exp(log_baseline_raw);

  for (n in 1:N) {
      log_lik[n] = poisson_log_lpmf(event[n] | log_hazard[n]);
  }

  // posterior predicted values
  for (samp in 1:S) {
      int sample_alive;
      sample_alive = 1;
      for (tp in 1:T) {
          if (sample_alive == 1) {
              int n;
              int pred_y;
              real log_haz;

              // determine predicted value of y
              // (need to recalc so that carried-forward data use sim tp and not t[n])
              n = n_trans[samp, tp];
              log_haz = log_baseline_mu + log_baseline[tp] + x[n,]*beta;

              // only draw when the log rate is below log(2^30); otherwise
              // fall back to the sentinel value 9
              if (log_haz < log(pow(2, 30))) {
                  pred_y = poisson_log_rng(log_haz);
              } else {
                  pred_y = 9;
              }

              // mark this patient as ineligible for future tps
              // note: deliberately make 9s ineligible
              if (pred_y >= 1) {
                  sample_alive = 0;
                  y_hat_time[samp] = t_obs[tp];
                  y_hat_event[samp] = 1;
              }

              // save predicted value of y to matrix
              y_hat_mat[samp, tp] = pred_y;
          }
          else if (sample_alive == 0) {
              // timepoints after the simulated event are filled with 9
              y_hat_mat[samp, tp] = 9;
          }
      } // end per-timepoint loop

      // if patient still alive at max
      if (sample_alive == 1) {
          y_hat_time[samp] = t_obs[T];
          y_hat_event[samp] = 0;
      }
  } // end per-sample loop
}

# Simulate survival data (or load it from the stancache cache). The function
# name and the N / censor_time values match the cache filename in the log
# output recorded below.
d = stancache.cached(
    survivalstan.sim.sim_data_exp_correlated,
    N=100,
    censor_time=20,
    rate_form='1 + sex',
    rate_coefs=[-3, 0.5],
)
# Center age on its sample mean; `age_centered` is used in the model formula.
d['age_centered'] = d['age'] - d['age'].mean()
INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl
INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache
age sex rate true_t t event index age_centered
0 59 male 0.082085 20.948771 20.000000 False 0 4.18
1 58 male 0.082085 12.827519 12.827519 True 1 3.18
2 61 female 0.049787 27.018886 20.000000 False 2 6.18
3 57 female 0.049787 62.220296 20.000000 False 3 2.18
4 55 male 0.082085 10.462045 10.462045 True 4 0.18
# Plot the observed survival curve separately for each sex (female first, as
# before), then draw the legend so the per-curve labels are shown.
for sex_label in ('female', 'male'):
    survivalstan.utils.plot_observed_survival(
        df=d[d['sex'] == sex_label],
        event_col='event',
        time_col='t',
        label=sex_label,
    )
plt.legend()
<matplotlib.legend.Legend at 0x7f5317b03cf8>
# Convert the per-subject data to long format (or load it from the stancache
# cache). The function name matches the cache filename in the log output
# recorded below; the result carries the `end_time` / `end_failure` columns
# used when fitting.
dlong = stancache.cached(
    survivalstan.prep_data_long_surv,
    df=d, event_col='event', time_col='t',
)
INFO:stancache.stancache:prep_data_long_surv: cache_filename set to prep_data_long_surv.cached.df_33772694934.event_col_event.time_col_t.pkl
INFO:stancache.stancache:prep_data_long_surv: Loading result from cache
age sex rate true_t t event index age_centered key end_time end_failure
0 59 male 0.082085 20.948771 20.0 False 0 4.18 1 20.000000 False
1 59 male 0.082085 20.948771 20.0 False 0 4.18 1 12.827519 False
2 59 male 0.082085 20.948771 20.0 False 0 4.18 1 10.462045 False
3 59 male 0.082085 20.948771 20.0 False 0 4.18 1 0.196923 False
4 59 male 0.082085 20.948771 20.0 False 0 4.18 1 9.244121 False
# Fit the PEM random-walk model to the long-format data. FIT_FUN routes the
# fit through stancache so compiled models and draws are reused across runs
# (see the "Loading result from cache" / "Saving results to cache" log lines).
testfit = survivalstan.fit_stan_survival_model(
    model_cohort='test model',
    model_code=model_code,
    df=dlong,
    sample_col='index',
    timepoint_end_col='end_time',
    event_col='end_failure',
    formula='~ age_centered + sex',
    iter=5000,
    chains=4,
    seed=9001,
    FIT_FUN=stancache.cached_stan_fit,
)

INFO:stancache.stancache:Step 1: Get compiled model code, possibly from cache
INFO:stancache.stancache:StanModel: cache_filename set to anon_model.cython_0_25_1.model_code_15125303112.pystan_2_12_0_0.stanmodel.pkl
INFO:stancache.stancache:StanModel: Loading result from cache
INFO:stancache.stancache:Step 2: Get posterior draws from model, possibly from cache
INFO:stancache.stancache:sampling: cache_filename set to anon_model.cython_0_25_1.model_code_15125303112.pystan_2_12_0_0.stanfit.chains_4.data_89490385305.iter_5000.seed_9001.pkl
INFO:stancache.stancache:sampling: Starting execution
INFO:stancache.stancache:sampling: Execution completed (0:11:49.722646 elapsed)
INFO:stancache.stancache:sampling: Saving results to cache
# Print a summary of the lp__ draws for this fit; the Rhat column of the
# printed table gives a quick convergence check.
survivalstan.utils.print_stan_summary([testfit], pars='lp__')
            mean   se_mean         sd        2.5%         50%       97.5%      Rhat
lp__ -299.460663  1.927055  26.352078 -347.240375 -301.398746 -241.969884  1.024231
# Plot a summary of the `log_baseline_raw` parameter (one element per
# timepoint in the Stan program).
survivalstan.utils.plot_stan_summary([testfit], pars='log_baseline_raw')
# Plot the `baseline` element -- presumably the `baseline` generated quantity,
# exp(log_baseline_raw); confirm against plot_coefs.
survivalstan.utils.plot_coefs([testfit], element='baseline')
# Posterior-predictive survival plot for the fit, drawn without fill. The
# recorded cell output is a Legend object, so the legend call is restored.
survivalstan.utils.plot_pp_survival([testfit], fill=False)
plt.legend()
<matplotlib.legend.Legend at 0x7f523338dc50>
# Posterior-predictive survival plot again, this time with by='sex'
# (presumably one curve per level of the `sex` column -- confirm).
survivalstan.utils.plot_pp_survival([testfit], by='sex')
